#!/usr/bin/python

import json
import random
import pickle
import time
import math
import copy
import os.path
import pandas as pd
import numpy as np
import argparse

from data.data import Data
from node.node import Node

def tree_pprinter(node):
    """Print a textual outline of the tree rooted at ``node`` to stdout.

    Every node is printed on its own line, prefixed by the number of rows
    in its data frame:

    * a split node shows its split variable and that variable's bounds;
    * a leaf shows the variance of the class variable followed by the
      list of variables whose bounds differ from ``[[-inf, inf]]``.

    Children are introduced with a `` `--`` marker. Descendants of a left
    child are indented with ``"  | "`` and descendants of a right child
    with four spaces.

    Args:
        node: Root of the (sub)tree. It must expose ``split_var``,
            ``left_child``, ``right_child`` and a ``data`` attribute with
            ``df``, ``class_var`` and ``var_desc``.

    Returns:
        None. The outline is written with ``print``.
    """

    def emit(current, prefix):
        # ``prefix`` is the indentation accumulated from the ancestors.
        info = current.data
        n_rows = len(info.df.index)

        if current.split_var is None:
            spread = np.var(info.df[info.class_var].values)
            constrained = [
                "{} {}".format(name, info.var_desc[name]['bounds'])
                for name in info.var_desc.keys()
                if info.var_desc[name]['bounds'] != [[-np.inf, np.inf]]
            ]
            print("({}) Leaf {} {}".format(n_rows, spread, constrained))
        else:
            split_bounds = info.var_desc[current.split_var]['bounds']
            print("({}) {} {}".format(n_rows, current.split_var, split_bounds))

        # The right child is only visited when a left child exists.
        if current.left_child is None:
            return

        for attr, pad in (("left_child", "  | "), ("right_child", "    ")):
            # The marker shares the line with the child's own description.
            print("{} `--".format(prefix), end="")
            emit(getattr(current, attr), prefix + pad)

    return emit(node, "")


def tree_trainer(df, class_var, var_desc, stop=50, variance=.01):
    """Build a tree from ``df`` and return its root node.

    The frame is wrapped in a ``Data`` object, a root ``Node`` is created
    from it and ``split()`` is called on that root.

    Args:
        df: pandas DataFrame holding the training rows.
        class_var: Name of the column in ``df`` holding the target.
        var_desc: Per-variable description dict, passed to ``Data``.
        stop: Forwarded to ``Node`` as its ``stop`` argument.
        variance: Forwarded to ``Node`` as its ``variance`` argument.

    Returns:
        The root ``Node`` after ``split()`` has been called on it.

    Raises:
        ValueError: If ``class_var`` is not a column of ``df``.
    """
    if class_var not in df.columns:
        # ValueError is still caught by callers using ``except Exception``.
        raise ValueError(
            'Class variable not in DataFrame: {!r}'.format(class_var))

    data = Data(df, class_var, var_desc)

    node = Node(data, stop=stop, variance=variance)
    node.split()

    return node


if __name__ == "__main__":
    # Command-line entry point. If a saved model named after the config file
    # already exists it is loaded and printed; otherwise a tree is trained
    # from the CSV data and the JSON configuration, then pickled to that file.
    parser = argparse.ArgumentParser(description='AeroMetTree model training command line utility.')
    parser.add_argument('--data', dest='data_path', help='Path to the CSV file containing the data to train the tree', required=True)
    parser.add_argument('--config', dest='config_path', help='Path to the JSON file containing the parameters to set model', required=True)

    args = parser.parse_args()

    # "<config basename without extension>.save", relative to the current
    # working directory.
    model_name = os.path.splitext(os.path.basename(args.config_path))[0] + ".save"

    if os.path.isfile(model_name):
        # NOTE: unpickling can execute arbitrary code; only load model files
        # that come from a trusted source.
        with open(model_name, "rb") as model_file:
            tree = pickle.load(model_file)
        tree_pprinter(tree)
        # NOTE(review): status 1 is kept from the original script even though
        # this path succeeded; confirm whether 0 is the intended status.
        # SystemExit is raised directly so the script does not rely on the
        # ``exit`` helper that the ``site`` module injects.
        raise SystemExit(1)

    # NOTE(review): the script used to read the CSV a first time with
    # ``.dropna()`` and then overwrite that frame with this plain read, so
    # training has always used the raw rows. Confirm whether rows containing
    # NaN should be dropped before training.
    df = pd.read_csv(args.data_path)

    # Hold the config file open only while it is being parsed.
    with open(args.config_path) as conf_file:
        tree_params = json.load(conf_file)

    tree_desc = {}
    for var in tree_params['input']:
        # Only "cir" variables get non-contiguous splits, and only when the
        # config turns contiguous splits off.
        if not tree_params['contiguous_splits'] and var['type'] == "cir":
            method = "non-cont"
        else:
            method = "cont"
        # Every variable starts unbounded, with its own fresh bounds list.
        tree_desc[var['name']] = {"type": var['type'], "method": method, "bounds": [[-np.inf, np.inf]]}
    print(tree_desc)

    tree = tree_trainer(df, tree_params['output']['name'], tree_desc, tree_params['max_leaf_size'])

    # Close the handle explicitly so the model is flushed to disk.
    with open(model_name, "wb") as model_file:
        pickle.dump(tree, model_file)
